{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/sharedata/mdy/env/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import triton\n",
    "import triton.language as tl\n",
    "import math\n",
    "from copy import deepcopy\n",
    "import os\n",
    "os.environ['TRITON_PRINT_AUTOTUNING'] = '1'\n",
    "# from nsa_attn import NsaAttention\n",
    "from flash_attn import flash_attn_func as fa2\n",
    "from flash_attn_interface import flash_attn_func as fa3\n",
    "from exp_family import compute_p, compute_select_p, _attention, fused_p, _cattention, select_for_bwd, select_for_fwd_bwd\n",
    "from select_attn import select_attn\n",
    "import random\n",
    "from fla.ops.nsa import parallel_nsa\n",
    "from pku_nsa import parallel_nsa_topk\n",
    "\n",
    "n = 1024 * 4\n",
    "kernel_size = 32\n",
    "stride = 16\n",
    "select_size = 64\n",
    "top_n = 16\n",
    "b, qh, kh, d, vd = 1, 64, 4, 128, 128\n",
    "sm_scale = d ** -0.5\n",
    "device = 'cuda'\n",
    "dtype = torch.bfloat16\n",
    "num_blocks = (n - kernel_size) // stride + 1\n",
    "q = torch.randn(b, n, qh, d, device=device, dtype=dtype)\n",
    "k = torch.randn(b, n, kh, d, device=device, dtype=dtype)\n",
    "v = torch.randn(b, n, kh, vd, device=device, dtype=dtype)\n",
    "ck = torch.randn(b, num_blocks, kh, d, device=device, dtype=dtype)\n",
    "cv = torch.randn(b, num_blocks, kh, vd, device=device, dtype=dtype)\n",
    "lse = torch.rand(b, qh, n, device=device, dtype=torch.float32) + 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# cmp_attn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True True\n",
      "3.978924512863159\n",
      "3.32159686088562\n"
     ]
    }
   ],
   "source": [
    "q.requires_grad_(True)\n",
    "ck.requires_grad_(True)\n",
    "cv.requires_grad_(True)\n",
    "q1 = deepcopy(q)\n",
    "ck1 = deepcopy(ck)\n",
    "cv1 = deepcopy(cv)\n",
    "y1, lse1 = _cattention.apply(q, ck, cv, kernel_size, stride, sm_scale, 1)\n",
    "y2, lse2 = _cattention.apply(q, ck, cv, kernel_size, stride, sm_scale, 2)\n",
    "dy = torch.rand_like(y1)\n",
    "print(torch.allclose(y1, y2), torch.allclose(lse1, lse2))\n",
    "\n",
    "print(triton.testing.do_bench(lambda: y1.backward(dy, retain_graph=True), grad_to_none=[q, ck, cv]))\n",
    "print(triton.testing.do_bench(lambda: y2.backward(dy, retain_graph=True), grad_to_none=[q1, ck1, cv1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3595395088195801\n",
      "0.3616417944431305\n",
      "0.428196519613266\n"
     ]
    }
   ],
   "source": [
    "print(triton.testing.do_bench(lambda: _cattention.apply(q, ck, cv, kernel_size, stride, sm_scale, 1)))\n",
    "print(triton.testing.do_bench(lambda: _cattention.apply(q, ck, cv, kernel_size, stride, sm_scale, 2)))\n",
    "print(triton.testing.do_bench(lambda: fa2(q, ck, cv, causal=False)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# compute_attn_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11.327719688415527\n",
      "10.787739753723145\n"
     ]
    }
   ],
   "source": [
    "# print(triton.testing.do_bench(lambda: compute_p(q, ck, lse, kernel_size, stride, method=1)))\n",
    "# print(triton.testing.do_bench(lambda: compute_p(q, ck, lse, kernel_size, stride, method=2)))# \n",
    "print(triton.testing.do_bench(lambda: compute_p(q, ck, lse, kernel_size, stride, method=3)))\n",
    "print(triton.testing.do_bench(lambda: compute_p(q, ck, lse, kernel_size, stride, method=4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# compute_select_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 4, 255, 4096])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prob = compute_p(q, ck, lse, kernel_size, stride, method=4)\n",
    "prob2 = prob.transpose(-1, -2).contiguous()\n",
    "prob.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.338936805725098\n",
      "2.9004673957824707\n",
      "2.299271583557129\n",
      "0.619713544845581\n"
     ]
    }
   ],
   "source": [
    "print(triton.testing.do_bench(lambda: compute_select_p(prob2, kernel_size, stride, select_size)))\n",
    "print(triton.testing.do_bench(lambda: compute_select_p(prob, kernel_size, stride, select_size, method=2)))\n",
    "print(triton.testing.do_bench(lambda: compute_select_p(prob, kernel_size, stride, select_size,top_n=16, method=3)))\n",
    "print(triton.testing.do_bench(lambda: compute_select_p(prob2, kernel_size, stride, select_size, top_n=16, method=4)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.02709120884537697\n"
     ]
    }
   ],
   "source": [
    "print(triton.testing.do_bench(lambda: select_for_fwd_bwd(prob2, kernel_size, stride, select_size, top_n=16)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "p, fwd_ind, bwd_ind, count = select_for_fwd_bwd(prob2, kernel_size, stride, select_size, top_n=16)\n",
    "bwd_ind2, count2 = select_for_bwd(fwd_ind.to(torch.int64))\n",
    "print(torch.allclose(count, count2[..., :-1].to(torch.int32)))\n",
    "for _ in range(100):\n",
    "    b_idx = random.randint(0, b-1)\n",
    "    h_idx = random.randint(0, kh-1)\n",
    "    row_idx = random.randint(0, n//select_size - 1)\n",
    "    cnt = count[b_idx, h_idx, row_idx]\n",
    "    val1, ind1 = bwd_ind[b_idx, h_idx, row_idx, :cnt].sort(-1)\n",
    "    val2, ind2 = bwd_ind2[b_idx, h_idx, row_idx, :cnt].sort(-1)\n",
    "    assert torch.allclose(val1.int(), val2.int())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(triton.testing.do_bench(lambda: fused_p(q, ck, lse, kernel_size, stride, select_size, top_n=3, sm_scale=sm_scale, method=2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(triton.testing.do_bench(lambda: fused_p(q, ck, lse, kernel_size, stride, select_size, top_n=3, sm_scale=sm_scale, method=4)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.823797225952148\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p1, indices1 = compute_select_p(prob, kernel_size, stride, select_size, top_n=16, method=3)\n",
    "p2, indices2 = compute_select_p(prob2, kernel_size, stride, select_size,top_n=16, method=4, return_p=True)\n",
    "print(triton.testing.do_bench(lambda: select_for_bwd(indices2)))\n",
    "torch.allclose(p1.transpose(-1, -2), p2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1., device='cuda:0')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(indices2 == fwd_ind).sum() / indices2.numel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.allclose(count, count2[0, :, :-1].to(torch.int32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# select_attn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 4, 4096, 16])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prob = compute_p(q, ck, lse, kernel_size, stride, method=4).transpose(-1, -2).contiguous()\n",
    "_, indices = compute_select_p(prob, kernel_size, stride, select_size,top_n=top_n, method=4, return_p=False)\n",
    "indices2 = indices.transpose(1,2).contiguous()\n",
    "indices.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Triton autotuning for function parallel_nsa_fwd_kernel finished after 0.88s; best config selected: num_warps: 1, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;\n",
      "Triton autotuning for function _fwd_kernel1 finished after 2.78s; best config selected: num_warps: 2, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;\n"
     ]
    }
   ],
   "source": [
    "y1 = parallel_nsa(q, k, v, indices2, select_size)\n",
    "y2 = _attention.apply(q, k, v, select_size, indices, sm_scale, 1)\n",
    "# y3 = _attention.apply(q, k, v, select_size, indices, sm_scale, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Triton autotuning for function parallel_nsa_fwd_kernel finished after 0.91s; best config selected: num_warps: 1, num_ctas: 1, num_stages: 2, num_buffers_warp_spec: 0, num_consumer_groups: 0, reg_dec_producer: 0, reg_inc_consumer: 0, maxnreg: None;\n",
      "0.8710374236106873\n"
     ]
    }
   ],
   "source": [
    "print(triton.testing.do_bench(lambda: parallel_nsa(q, k, v, indices2, select_size)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8875794410705566\n"
     ]
    }
   ],
   "source": [
    "print(triton.testing.do_bench(lambda: _attention.apply(q, k, v, select_size, indices, sm_scale, 1)))\n",
    "# print(triton.testing.do_bench(lambda: _attention.apply(q, k, v, select_size, indices, sm_scale, 2)))\n",
    "# print(triton.testing.do_bench(lambda: _attention.apply(q, k, v, select_size, indices, sm_scale, 3)))\n",
    "# print(triton.testing.do_bench(lambda: fa2(q, k[:, :select_size*top_n], v[:, :select_size*top_n])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "q.requires_grad_(True)\n",
    "k.requires_grad_(True)\n",
    "v.requires_grad_(True)\n",
    "q1 = deepcopy(q)\n",
    "k1 = deepcopy(k)\n",
    "v1 = deepcopy(v)\n",
    "y1 = _attention.apply(q, k, v, select_size, indices, sm_scale, 1)\n",
    "y2 = _attention.apply(q1, k1, v1, select_size, indices, sm_scale, 3)\n",
    "dy = torch.randn_like(y1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16.976537704467773\n"
     ]
    }
   ],
   "source": [
    "print(triton.testing.do_bench(lambda: y1.backward(dy, retain_graph=True), grad_to_none=[q, k, v]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18.112581253051758\n"
     ]
    }
   ],
   "source": [
    "print(triton.testing.do_bench(lambda: y1.backward(dy, retain_graph=True), grad_to_none=[q, k, v]))\n",
    "# print(triton.testing.do_bench(lambda: y2.backward(dy, retain_graph=True), grad_to_none=[q1, k1, v1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0469e+00,  2.6250e+00, -7.0801e-02,  ..., -4.1211e-01,\n",
       "         -1.7812e+00,  1.8457e-01],\n",
       "        [-1.0234e+00,  2.1094e+00,  8.7891e-02,  ...,  5.8594e-01,\n",
       "         -1.2812e+00,  1.0234e+00],\n",
       "        [-4.7852e-01, -6.1719e-01, -8.5547e-01,  ...,  8.4375e-01,\n",
       "         -3.7695e-01,  1.0391e+00],\n",
       "        ...,\n",
       "        [-4.9316e-02,  6.3477e-02,  3.4668e-02,  ..., -3.4668e-02,\n",
       "         -1.4282e-02, -2.5146e-02],\n",
       "        [ 8.8501e-03,  7.6599e-03, -5.8838e-02,  ..., -2.3071e-02,\n",
       "          1.0452e-03, -1.3855e-02],\n",
       "        [ 1.3855e-02,  2.4536e-02, -1.1230e-01,  ...,  2.9907e-03,\n",
       "         -4.4678e-02,  4.2236e-02]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y1[0, :, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0469e+00,  2.6250e+00, -7.0801e-02,  ..., -4.1211e-01,\n",
       "         -1.7812e+00,  1.8457e-01],\n",
       "        [-1.0234e+00,  2.1094e+00,  8.7891e-02,  ...,  5.8594e-01,\n",
       "         -1.2812e+00,  1.0234e+00],\n",
       "        [-4.7852e-01, -6.1719e-01, -8.5547e-01,  ...,  8.4375e-01,\n",
       "         -3.7695e-01,  1.0391e+00],\n",
       "        ...,\n",
       "        [-4.9316e-02,  6.3477e-02,  3.4668e-02,  ..., -3.4668e-02,\n",
       "         -1.4282e-02, -2.5146e-02],\n",
       "        [ 8.8501e-03,  7.6599e-03, -5.8838e-02,  ..., -2.3071e-02,\n",
       "          1.0452e-03, -1.3855e-02],\n",
       "        [ 1.3855e-02,  2.4536e-02, -1.1230e-01,  ...,  2.9907e-03,\n",
       "         -4.4678e-02,  4.2236e-02]], device='cuda:0', dtype=torch.bfloat16,\n",
       "       grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y2[0, :, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
